{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Crello analysis\n",
    "This notebook qualitatively analyzes learned models in Crello dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Editable parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Please edit these parameters\n",
    "ckpt_dir = \"../results/crello/ours-exp-ft/checkpoints\"\n",
    "dataset_name = \"crello\"\n",
    "db_root = \"../data/crello\"\n",
    "batch_size = 4\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import logging\n",
    "import sys\n",
    "from typing import Dict, List\n",
    "from pathlib import Path\n",
    "\n",
    "import tensorflow as tf\n",
    "from IPython.display import display, HTML\n",
    "%matplotlib inline\n",
    "\n",
    "sys.path.append(\"../src/mfp\")\n",
    "\n",
    "from mfp.data.spec import ATTRIBUTE_GROUPS, DataSpec, set_visual_default\n",
    "from mfp.helpers.retrieve import ImageRetriever, TextRetriever\n",
    "from mfp.helpers.svg_crello import SVGBuilder\n",
    "from mfp.models.mfp import MFP\n",
    "from mfp.models.architecture.mask import get_seq_mask\n",
    "from mfp.models.masking import get_initial_masks\n",
    "from util import grouper, load_model\n",
    "\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# fix seed for debug\n",
    "tf.random.set_seed(0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "dataspec = DataSpec(dataset_name, db_root, batch_size)\n",
    "test_dataset = dataspec.make_dataset(\"test\", shuffle=False)\n",
    "\n",
    "iterator = iter(test_dataset.take(1))\n",
    "example = next(iterator)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Load pre-trained models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_columns = dataspec.make_input_columns()\n",
    "models = {\"main\": load_model(ckpt_dir, input_columns=input_columns)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Build DB for image/text retrieval and visualizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: this part takes several minutes due to building search index\n",
    "db_root = Path(db_root)\n",
    "image_db = ImageRetriever(db_root, image_path=db_root / \"images\")\n",
    "image_db.build(\"test\")\n",
    "text_db = TextRetriever(db_root, text_path=db_root / \"texts\")\n",
    "text_db.build(\"test\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Build visualizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builders = {}\n",
    "builders[\"layout\"] = SVGBuilder(\n",
    "    max_width=128,\n",
    "    max_height=192,\n",
    "    key=\"type\",\n",
    "    preprocessor=dataspec.preprocessor,\n",
    ")\n",
    "patterns = (\n",
    "    (\"visual\", image_db, text_db),\n",
    "    (\"visual_wo_text\", image_db, None),\n",
    "    (\"visual_wo_image\", None, text_db),\n",
    ")\n",
    "\n",
    "for (name, idb, tdb) in patterns:\n",
    "    builders[name] = SVGBuilder(\n",
    "        max_width=128,\n",
    "        max_height=192,\n",
    "        key=\"color\",\n",
    "        preprocessor=dataspec.preprocessor,\n",
    "        image_db=idb,\n",
    "        text_db=tdb,\n",
    "        render_text=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_task = \"attr\"  # choose from: elem, pos, attr, txt, img\n",
    "column_names = {\n",
    "    \"txt\": [\"gt-layout\", \"gt-visual\", \"input\", \"pred\"],\n",
    "    \"img\": [\"gt-layout\", \"gt-visual\", \"input\", \"pred\"],\n",
    "    \"attr\": [\"gt-layout\", \"gt-visual\", \"input\", \"pred\"],\n",
    "    \"pos\": [\"gt-layout\", \"gt-visual\", \"pred-layout\", \"pred-visual\"],\n",
    "    \"elem\": [\"gt-layout\", \"gt-visual\", \"input-layout\", \"input-visual\", \"pred-layout\", \"pred-visual\"],\n",
    "}\n",
    "\n",
    "def visualize_reconstruction(\n",
    "    models: List[tf.keras.Model],\n",
    "    example: Dict,\n",
    "    dataspec: DataSpec\n",
    "):\n",
    "    svgs = []\n",
    "    items = dataspec.unbatch(example)\n",
    "    svgs.append(list(map(builders[\"layout\"], items)))\n",
    "    svgs.append(list(map(builders[\"visual\"], items)))\n",
    "    if target_task == \"txt\":\n",
    "        svgs.append(list(map(builders[\"visual_wo_text\"], items)))\n",
    "    elif target_task == \"img\":\n",
    "        svgs.append(list(map(builders[\"visual_wo_image\"], items)))\n",
    "    elif target_task == \"attr\":\n",
    "        svgs.append(list(map(builders[\"visual\"], [set_visual_default(x) for x in items])))\n",
    "\n",
    "    seq_mask = get_seq_mask(example[\"length\"])\n",
    "    mfp_masks = get_initial_masks(input_columns, seq_mask)\n",
    "\n",
    "    for key in mfp_masks.keys():\n",
    "        if not input_columns[key][\"is_sequence\"]:\n",
    "            continue\n",
    "        mask = mfp_masks[key].numpy()\n",
    "\n",
    "        if target_task == \"elem\":\n",
    "            target_indices = [0]  # hide first\n",
    "            for i in range(len(target_indices)):\n",
    "                mask[i, target_indices[i]] = True\n",
    "        else:\n",
    "            if key == \"type\":\n",
    "                continue\n",
    "            attr_groups = ATTRIBUTE_GROUPS[\"crello\"][target_task]\n",
    "            if key in attr_groups:\n",
    "                mask = seq_mask\n",
    "\n",
    "        mfp_masks[key] = tf.convert_to_tensor(mask)\n",
    "\n",
    "    if target_task == \"elem\":\n",
    "        example_copy = {}\n",
    "        for key in example.keys():\n",
    "            # note: assuming similar mask place in a batch\n",
    "            if example[key].shape[1] > 1:\n",
    "                B, S = example[key].shape[:2]\n",
    "                indices = tf.where(~mfp_masks[key][0, :])[:, 0]\n",
    "                example_copy[key] = tf.gather(\n",
    "                    example[key], indices, axis=1\n",
    "                )\n",
    "                # print(key, example_copy[key].shape)\n",
    "            else:\n",
    "                example_copy[key] = example[key]\n",
    "        example_copy[\"length\"] -= 1\n",
    "        items = dataspec.unbatch(example_copy)\n",
    "        svgs.append(list(map(builders[\"layout\"], items)))\n",
    "        svgs.append(list(map(builders[\"visual\"], items)))\n",
    "\n",
    "    for model in models:\n",
    "        pred = model(example, training=False, demo_args={\"masks\": mfp_masks})\n",
    "        for key in example:\n",
    "            if key not in pred:\n",
    "                pred[key] = example[key]\n",
    "\n",
    "        if target_task in [\"pos\", \"elem\"]:\n",
    "            svgs.append(list(map(builders[\"layout\"], dataspec.unbatch(pred))))\n",
    "        svgs.append(list(map(builders[\"visual\"], dataspec.unbatch(pred))))\n",
    "\n",
    "    return [list(grouper(row, len(column_names[target_task]))) for row in zip(*svgs)]\n",
    "\n",
    "iterator = iter(test_dataset.take(1))\n",
    "example = next(iterator)\n",
    "\n",
    "print(f\"From left to right: {','.join(column_names[target_task])}\")\n",
    "svgs = visualize_reconstruction(models.values(), example, dataspec)\n",
    "for i, row in enumerate(svgs):\n",
    "    print(i)\n",
    "    display(HTML(\"<div>%s</div>\" % \" \".join(itertools.chain.from_iterable(row))))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.12 ('.venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "b1cc4bcd870fb3eb296f14a2aa1daa467f3c14f214d3fd2c136db8d539a2d2c9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
